import torch

def set_max_split_size_mb(model, max_split_size_mb):
    """
    Configure the CUDA caching allocator's ``max_split_size_mb`` option.

    The caching allocator will not split cached blocks larger than this size,
    which can reduce fragmentation when allocation sizes vary. PyTorch exposes
    this option through the ``PYTORCH_CUDA_ALLOC_CONF`` environment variable,
    which is read when the CUDA allocator is first used. This function
    therefore does two things:

    * It merges ``max_split_size_mb:<value>`` into ``PYTORCH_CUDA_ALLOC_CONF``.
      Any previous ``max_split_size_mb`` entry is replaced and all other
      options are kept. This takes effect if CUDA has not been initialized
      in this process yet.
    * If CUDA is already initialized, it also pushes the setting to the live
      allocator via ``torch.cuda.memory._set_allocator_settings`` (a private
      API). If that API is unavailable, a ``RuntimeWarning`` is emitted,
      because the new value cannot take effect in this process.

    Args:
        model (torch.nn.Module): Unused. It is kept only for backward
            compatibility with existing callers, because the allocator
            setting is process-wide and independent of any model. The
            model's parameters, including their ``requires_grad`` flags,
            are not modified.
        max_split_size_mb (int): Maximum splittable block size in megabytes.
            Must be an integer greater than 20, since PyTorch rejects
            smaller values.

    Returns:
        str: The resulting value of ``PYTORCH_CUDA_ALLOC_CONF``.

    Raises:
        TypeError: If ``max_split_size_mb`` is not an integer.
        ValueError: If ``max_split_size_mb`` is 20 or less.
    """
    import operator
    import os
    import warnings

    try:
        size_mb = operator.index(max_split_size_mb)
    except TypeError:
        raise TypeError(
            "max_split_size_mb must be an integer, got "
            f"{type(max_split_size_mb).__name__}"
        ) from None
    # The allocator requires a value strictly above its 20 MB large-buffer
    # threshold. Checking here fails fast, instead of at the first CUDA
    # allocation when the environment variable is finally parsed.
    if size_mb <= 20:
        raise ValueError(f"max_split_size_mb must be > 20, got {size_mb}")

    setting = f"max_split_size_mb:{size_mb}"

    # Options are comma-separated "key:value" pairs. Only existing
    # max_split_size_mb entries are dropped. Every other fragment is kept
    # verbatim, including pieces of bracketed lists such as
    # roundup_power2_divisions:[256:1,512:2], so rejoining restores them.
    existing = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    kept = [
        opt
        for opt in existing.split(",")
        if opt.strip() and opt.split(":", 1)[0].strip() != "max_split_size_mb"
    ]
    conf = ",".join(kept + [setting])
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf

    # Once CUDA is initialized, the environment variable has already been
    # consumed, so the live allocator has to be updated directly.
    if torch.cuda.is_initialized():
        apply_settings = getattr(torch.cuda.memory, "_set_allocator_settings", None)
        if apply_settings is None:
            warnings.warn(
                "CUDA is already initialized and this PyTorch build does not "
                "expose torch.cuda.memory._set_allocator_settings; "
                f"'{setting}' will only apply to new processes.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            apply_settings(setting)

    return conf

# Example usage: run this file directly to apply the allocator setting.
if __name__ == "__main__":
    # Small stand-in network; substitute your own nn.Module here.
    demo_model = torch.nn.Linear(10, 5)

    # Upper bound, in megabytes, on cached blocks the allocator may split.
    split_limit_mb = 200

    set_max_split_size_mb(demo_model, split_limit_mb)